import torch
import os
import numpy as np
import cv2
from torchvision.transforms import functional
import random
from aip import AipFace
import time
import tkinter as tk
from tkinter import filedialog
import shutil
import socket
from predict_once1 import face_iou_cal, face_detection, predict_faces
from PIL import Image, ImageDraw, ImageFont

addr = ('192.168.1.53', 20020)    # this machine's IP address and port (recv_txt_TCP binds here)

""" 你的 APPID AK SK """
# SECURITY NOTE(review): Baidu AI credentials are hard-coded in source; move
# them to environment variables or a config file before sharing this code.
APP_ID = '22912073'
API_KEY = '3zbPMNiqWOrsD5BspgX1pBoR'
SECRET_KEY = 'nxSoiCZnSdOV9DVTPhHrlMvAYvTHZNa4'
# Baidu face-recognition client (used by the predict_once1 helpers)
aipFace = AipFace(APP_ID, API_KEY, SECRET_KEY)

# Maps a face-search user_id to its display string "<Name> <Gender>".
# NOTE(review): not referenced in this file -- presumably consumed by
# predict_once1 or another module; verify before removing.
faces_name_dict = {
    'Paulmale': 'Paul Male',
    'Danielmale': 'Daniel Male',
    'Fishermale': 'Fisher Male',
    'Jackmale': 'Jack Male',
    'Kevinmale': 'Kevin Male',
    'lilyfemale': 'lily Female',
    'Rosefemale': 'Rose Female',
    'Maryfemale': 'Mary Female',
    'Michaelmale': 'Michael Male',
    'stevenmale': 'Steven Male',
    'Jamesmale': 'James Male'
}


# End-of-transmission marker for TCP transfers
# NOTE(review): not used in this chunk -- may be consumed elsewhere; verify.
FLAG = b'$END$'


def recv_txt_TCP(addr, BUF=1024):
    """Listen on *addr*, accept one TCP connection and return the first
    received chunk decoded as text.

    :param addr: (host, port) tuple this machine binds to
    :param BUF: receive buffer size in bytes (bug fix: the original
                ignored this parameter and hard-coded ``recv(1024)``)
    :return: the received bytes decoded with the default UTF-8 codec
    """
    # Bug fix: the original never closed either socket; `with` guarantees
    # both are released even if recv()/decode() raises.
    with socket.socket() as sk:
        # Allow immediate re-bind of the port across repeated runs.
        sk.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sk.bind(addr)
        sk.listen(10)  # backlog of up to 10 pending connections
        client, cli_addr = sk.accept()  # blocks until a client connects
        with client:
            data = client.recv(BUF)
    return data.decode()


def predict_once(model, img):
    """Run one inference pass of a detection *model* on a single image.

    :param model: torchvision-style detection model; called with a list of
                  image tensors and expected to return a list of dicts with
                  'boxes', 'labels' and 'scores' entries
    :param img: an image accepted by
                torchvision.transforms.functional.to_tensor
    :return: the raw prediction list produced by the model
    """
    # Switch to inference mode (disables dropout / batch-norm updates).
    model.eval()

    tensor_img = functional.to_tensor(img)

    # Prefer the GPU when one is available.
    # NOTE(review): assumes the model already lives on this device --
    # confirm at the call site.
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        result = model([tensor_img.to(run_device)])

    return result


def random_color():
    """Return a random mid-brightness BGR colour tuple.

    Each channel is drawn uniformly from [50, 200] so boxes stay readable
    against both black label text and typical photo backgrounds.
    """
    return tuple(random.randint(50, 200) for _ in range(3))

def _collect_faces(detect_fn, start_img_path, temp_root, margin, temp_prefix):
    """Iteratively detect faces in an image until none are left.

    The Baidu API returns a limited number of faces per request, so after
    each call the freshly detected regions are painted black and the masked
    image is submitted again until a call comes back empty.

    :param detect_fn: callable(image_path) -> list of face dicts, each with
                      a 'location' entry holding left/top/width/height
    :param start_img_path: path of the original (unmasked) image
    :param temp_root: directory where intermediate masked images are written
    :param margin: extra pixels blacked out around each detected face box
    :param temp_prefix: filename prefix for the intermediate images
    :return: accumulated list of every face dict returned by *detect_fn*
    """
    collected = []
    next_img_path = start_img_path
    round_idx = 0
    while True:
        round_idx += 1
        time_start = time.time()
        faces_once = detect_fn(next_img_path)

        if not faces_once:  # nothing detected any more -> finished
            break
        collected.extend(faces_once)

        # Black out every face found this round so it is not detected again.
        masked_img = cv2.imread(next_img_path)
        for face in faces_once:
            location = face['location']
            left = location['left']
            top = location['top']
            width = location['width']
            height = location['height']
            cv2.rectangle(masked_img,
                          (int(left - margin), int(top - margin)),
                          (int(left + width + margin), int(top + height + margin)),
                          color=(0, 0, 0),
                          thickness=-1)  # filled rectangle = mask

        # Save the masked image; it becomes the input of the next round.
        next_img_path = os.path.join(temp_root, temp_prefix + str(round_idx) + '.jpg')
        cv2.imwrite(next_img_path, masked_img, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])

        # Rate-limit consecutive API calls to at least 0.5 s apart.
        # (Bug fix: the original slept for the *elapsed* time instead of
        # the remaining `0.5 - elapsed`.)
        elapsed = time.time() - time_start
        if elapsed < 0.5:
            time.sleep(0.5 - elapsed)
    return collected


def main(args):
    """Interactive end-to-end prediction demo.

    Workflow:
      1. let the user pick an image via a file dialog;
      2. run Baidu face search (known identities) and face detection
         (gender of strangers) on it;
      3. receive object-detection boxes over TCP from a peer machine;
      4. draw all boxes/labels on the image, display it and save the
         annotated image plus PREDICTION_DOCUMENT.txt.

    :param args: parsed command-line arguments (not read inside the body;
                 kept for interface compatibility with the launcher)
    """
    # Pretend to load a model (kept for UX parity with the original tool).
    print('Loading Selected Model...\n')
    time.sleep(0.5)
    print('Loading Selected Model...Done\n')

    # Let the user pick any picture; it is copied into the temp folder.
    print('Please select A picture:\n')
    root = r'dataset'
    temp_root = os.path.join(root, 'temp')
    prediction_root = os.path.join(root, 'prediction')
    prediction_txt_path = os.path.join(prediction_root, 'PREDICTION_DOCUMENT.txt')
    if not os.path.exists(temp_root):
        os.mkdir(temp_root)
    else:
        # Clear leftovers from a previous run.
        for file_name in os.listdir(temp_root):
            os.remove(os.path.join(temp_root, file_name))
    # Bug fix: the original never created prediction_root, so the first run
    # crashed when opening the txt file below.
    if not os.path.exists(prediction_root):
        os.makedirs(prediction_root)
    boost_root = r'predict_boost'
    if not os.path.exists(boost_root):
        os.mkdir(boost_root)
    tk1 = tk.Tk()
    tk1.withdraw()  # hide the empty Tk root window; only show the dialog
    img_path = filedialog.askopenfilename()
    src = img_path
    dst = os.path.join(temp_root, 'origin_img.jpg')
    img_path = dst  # work on the copied image from here on
    shutil.copy(src, dst)

    # Start a fresh prediction document.
    with open(prediction_txt_path, 'w', encoding='utf-8') as file:
        file.write('PREDICTION:\n')

    # Load and show the picture, then wait for a key press to start.
    src_img = cv2.imread(img_path)
    cv2.namedWindow('img')
    time.sleep(0.1)
    cv2.imshow('img', src_img)
    print('Wait key to predict once...\n')
    cv2.waitKey()
    print('Predicting...\n')

    # Face search (M:N, known identities).
    faces = _collect_faces(predict_faces, img_path, temp_root, 5,
                           'face_search_temp')

    # Drop low-confidence search matches.
    face_search_th = 65
    faces = [face for face in faces if face['score'] >= face_search_th]

    # Face detection (gender); extra pause so the two API features
    # do not collide with the rate limit.
    time.sleep(0.5)
    gender_faces = _collect_faces(face_detection, img_path, temp_root, 15,
                                  'face_detection_temp')

    # Remove gender-detection faces that overlap an identified face, so each
    # person is drawn only once.
    face_iou_th = 0.9
    for idx in range(len(gender_faces) - 1, -1, -1):  # reverse -> safe pop
        gender_face = gender_faces[idx]
        for face in faces:
            if face_iou_cal(face, gender_face) > face_iou_th:
                gender_faces.pop(idx)
                break  # bug fix: original kept looping and could pop the
                       # element that shifted into this index as well

    # Timestamp and team-name overlay.  PIL renders the Chinese text,
    # which cv2.putText cannot draw.
    cv2.putText(src_img, time.strftime('%F--%Hh%Mm%Ss'), (0, 25),
                fontFace=cv2.FONT_HERSHEY_COMPLEX,
                fontScale=1, color=(255, 0, 0), thickness=1)
    font_path = 'simsun.ttc'
    chinese_font = ImageFont.truetype(font_path, 40)
    team_name = 'NNU识别1队'
    img_pil = Image.fromarray(src_img)
    draw = ImageDraw.Draw(img_pil)
    draw.text((0, 50), team_name, font=chinese_font, fill=(255, 0, 0, 0))
    src_img = np.array(img_pil)

    # Receive object-detection boxes over TCP.  The wire format is
    # "x1 y1 x2 y2+label/x1 y1 x2 y2+label/..." -- boxes separated by '/',
    # coordinates by spaces, the label glued to y2 with '+'.
    boxes = recv_txt_TCP(addr)
    box_list_ori = [box for box in boxes.split('/') if box != '']
    for box_ori in box_list_ori:
        color = random_color()
        label = box_ori.split('+')[-1]
        coords = box_ori.split(' ')
        x1 = coords[0]
        y1 = coords[1]
        x2 = coords[2]
        y2 = coords[3].split('+')[0]  # strip the '+label' suffix
        cv2.rectangle(src_img, (int(x1), int(y1)), (int(x2), int(y2)),
                      color=color, thickness=2)
        cv2.rectangle(src_img, (int(x1), int(y2)), (int(x2), int(y2) + 20),
                      color=color, thickness=-1)  # label background strip
        cv2.putText(src_img, label, (int(x1), int(y2) + 16),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.5,
                    color=(0, 0, 0))

        # Log the box to the prediction document.
        with open(prediction_txt_path, 'a', encoding='utf-8') as f:
            f.write('boxes:\t')
            f.write('x1: ' + str(int(x1)) + '\t')
            f.write('y1: ' + str(int(y1)) + '\t')
            f.write('x2: ' + str(int(x2)) + '\t')
            f.write('y2: ' + str(int(y2)) + '\t')
            f.write('label: ' + str(label) + '\t')
            f.write('\n')

    # Draw unidentified faces (gender only, labelled "Stranger").
    for gender_face in gender_faces:
        location = gender_face['location']
        gender = gender_face['gender']
        if gender == 'male':
            gender = 'Male'
        if gender == 'female':
            gender = 'Female'

        left = location['left']
        top = location['top']
        width = location['width']
        height = location['height']

        color = random_color()
        cv2.rectangle(src_img, (int(left), int(top)),
                      (int(left + width), int(top + height)),
                      color=color, thickness=2)  # face box
        cv2.rectangle(src_img, (int(left), int(top + height)),
                      (int(left + width), int(top + height + 40)),
                      color=color,
                      thickness=-1)  # label background
        cv2.putText(src_img, 'Stranger', (int(left), int(top + height + 16)),
                    fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                    fontScale=0.5, color=(0, 0, 0), thickness=1)  # identity
        cv2.putText(src_img, gender, (int(left), int(top + height + 32)),
                    fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                    fontScale=0.5, color=(0, 0, 0), thickness=1)  # gender

        # Log the face to the prediction document.
        with open(prediction_txt_path, 'a', encoding='utf-8') as f:
            f.write('boxes:\t')
            f.write('x1: ' + str(int(left)) + '\t')
            f.write('y1: ' + str(int(top)) + '\t')
            f.write('x2: ' + str(int(left + width)) + '\t')
            f.write('y2: ' + str(int(top + height)) + '\t')
            f.write(' name: ' + 'Stranger' + '\t')
            f.write(' gender: ' + gender)
            f.write('\n')

    # Draw identified faces (face-search results, name + gender).
    for face in faces:
        location = face['location']
        user_id = face['user_id']
        gender = user_id.split(' ')[-1]  # user_id is "<Name> <Gender>"
        user_id = user_id.split(' ')[0]

        left = location['left']
        top = location['top']
        width = location['width']
        height = location['height']

        color = random_color()
        cv2.rectangle(src_img, (int(left), int(top)),
                      (int(left + width), int(top + height)),
                      color=color, thickness=2)  # face box
        cv2.rectangle(src_img, (int(left), int(top + height)),
                      (int(left + width), int(top + height + 40)),
                      color=color,
                      thickness=-1)  # label background
        cv2.putText(src_img, user_id, (int(left), int(top + height + 16)),
                    fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                    fontScale=0.5, color=(0, 0, 0), thickness=1)  # identity
        cv2.putText(src_img, gender, (int(left), int(top + height + 32)),
                    fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                    fontScale=0.5, color=(0, 0, 0), thickness=1)  # gender

        # Log the face to the prediction document.
        with open(prediction_txt_path, 'a', encoding='utf-8') as f:
            f.write('boxes:\t')
            f.write('x1: ' + str(int(left)) + '\t')
            f.write('y1: ' + str(int(top)) + '\t')
            f.write('x2: ' + str(int(left + width)) + '\t')
            f.write('y2: ' + str(int(top + height)) + '\t')
            f.write('name: ' + user_id + '\t')
            f.write('gender: ' + gender + '\t')
            f.write('\n')

    # Show and save the annotated picture.
    print('Prediction Done!\n')
    cv2.imshow('img', src_img)
    img_save_path = os.path.join(prediction_root, '1' + 'PREDICTION_img.jpg')
    print('Save to' + img_save_path + '\n')
    cv2.imwrite(img_save_path, src_img, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])

    print('Press any key to continue\n')
    print('Press ESC to exit\n')
    cv2.waitKey()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)

    # One or more integers (epoch counts); required for compatibility with
    # the launcher scripts that drive this tool.
    parser.add_argument('-n', '--num', default=[0, 0, 0, 0, 0, 0, 0, 0],
                        type=int, metavar='N1, N2, N3...',
                        help='num of epochs', required=True, nargs='+')
    # Name of the model the tool pretends to load.
    parser.add_argument('-m', '--model', help='model_name', required=True)

    args = parser.parse_args()

    # Hand the parsed command-line arguments to the main routine.
    main(args)
